import torch
from torch import nn
import math
from sshcode.Transformer.Decoder import output

# Output layer ("generator"): maps decoder features to scores over the vocabulary.
class Generator(nn.Module):
    """Output layer: project decoder features onto a distribution over the vocabulary.

    A single ``nn.Linear`` maps the last dimension from ``d_model`` to
    ``vocab_size``, followed by a softmax (or log-softmax) over that last
    dimension.

    Args:
        d_model: Size of the last dimension of the input features.
        vocab_size: Number of output entries (vocabulary size).
        log_probs: Keyword-only. ``False`` (default, original behaviour)
            makes ``forward`` return probabilities via ``softmax``. ``True``
            makes it return log-probabilities via ``log_softmax``, which is
            numerically stabler than ``torch.log(softmax(...))``. The latter
            yields ``-inf`` wherever a probability underflows to 0.

    NOTE(review): ``nn.CrossEntropyLoss`` applies log-softmax internally and
    expects raw logits, so do not feed it this layer's output. Confirm the
    training loop pairs this with an appropriate loss.
    """

    def __init__(self, d_model: int, vocab_size: int, *, log_probs: bool = False) -> None:
        super().__init__()
        # Attribute name kept as ``linear`` so state_dict keys stay compatible.
        self.linear = nn.Linear(d_model, vocab_size)
        self.log_probs = log_probs

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return (log-)probabilities over the vocabulary along the last dim.

        Args:
            x: Tensor whose last dimension has size ``d_model``; ``nn.Linear``
                preserves all leading dimensions.

        Returns:
            Tensor with the same leading shape as ``x`` and last dimension
            ``vocab_size``. It holds probabilities (default) or
            log-probabilities when ``log_probs=True``.
        """
        logits = self.linear(x)
        if self.log_probs:
            return torch.log_softmax(logits, dim=-1)
        return torch.softmax(logits, dim=-1)

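# Manual smoke test (kept commented out): run the decoder's `output` tensor
# (imported at the top of this file) through a Generator with d_model=8 and
# vocab_size=20, then inspect the probabilities, their shape and the argmax ids.
# generator = Generator(8,20)
# predict = generator(output)
# print(predict)
# print(predict.shape)
# print(torch.argmax(predict,dim=-1))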